{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5b24b19b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0613,  0.1030,  0.1104,  0.0062,  0.3094,  0.0503,  0.0666, -0.1326,\n",
       "         -0.0145,  0.0998],\n",
       "        [-0.0158,  0.0448,  0.0225, -0.0139,  0.2253,  0.0265,  0.0997, -0.1417,\n",
       "          0.0835,  0.0197]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "\n",
    "net = nn.Sequential(\n",
    "    nn.Linear(20, 256),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(256, 10))\n",
    "\n",
    "X = torch.rand(2, 20)\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e2d45bc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    # 用模型参数声明层\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.hidden = nn.Linear(20, 256)\n",
    "        self.out = nn.Linear(256, 10)\n",
    "    \n",
    "    # 定义模型的前向传播，即如何根据输入X返回所需的模型输出\n",
    "    def forward(self, X):\n",
    "        return self.out(F.relu(self.hidden(X)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f81aab68",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-4.6609e-02, -1.4985e-02,  3.0498e-01, -4.1742e-02,  1.2712e-01,\n",
       "          3.2396e-01,  5.2925e-02, -9.4702e-02,  1.6149e-01, -6.8582e-02],\n",
       "        [-8.1018e-03, -1.2486e-02,  2.1394e-01, -1.7008e-04, -6.1914e-02,\n",
       "          2.9943e-01,  9.0445e-02, -1.1172e-01,  5.9498e-02, -4.8344e-02]],\n",
       "       grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Instantiate the custom MLP and run it on the batch X.\n",
    "net = MLP()\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3db777f2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0680, -0.0864, -0.2025,  0.1701, -0.1013, -0.0672,  0.1442,  0.0058,\n",
       "          0.1022, -0.1417],\n",
       "        [ 0.0035, -0.1082, -0.1689,  0.1438, -0.1829, -0.1488,  0.1201, -0.0097,\n",
       "          0.1115, -0.1597]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class MySequential(nn.Module):\n",
    "    def __init__(self, *args):\n",
    "        super().__init__()\n",
    "        for idx, module in enumerate(args):\n",
    "            self._modules[str(idx)] = module\n",
    "    def forward(self, X):\n",
    "        for block in self._modules.values():\n",
    "            X = block(X)\n",
    "        return X\n",
    "    \n",
    "# Same layer stack as the nn.Sequential baseline, built with MySequential.\n",
    "net = MySequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4d4f13e2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.2155, grad_fn=<SumBackward0>)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class FixHiddenMLP(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.rand_weight = torch.rand((20,20), requires_grad=False)\n",
    "        self.linear = nn.Linear(20, 20)\n",
    "    \n",
    "    def forward(self, X):\n",
    "        X = self.linear(X)\n",
    "        # 使用创建的常量参数以及relu和mm函数\n",
    "        X = F.relu(torch.mm(X, self.rand_weight) + 1)\n",
    "        # 复用全连接层，这相当于两个全连接层共享参数\n",
    "        X = self.linear(X)\n",
    "        # 控制流\n",
    "        while X.abs().sum() > 1:\n",
    "            X /= 2\n",
    "        return X.sum()\n",
    "# Instantiate the model and evaluate it on the batch X (scalar result).\n",
    "net = FixHiddenMLP()\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "531810a4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(-0.2760, grad_fn=<SumBackward0>)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class NestMLP(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(20, 64),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(64, 32),\n",
    "            nn.ReLU())\n",
    "        self.linear = nn.Linear(32, 16)\n",
    "    def forward(self, X):\n",
    "        return self.linear(self.net(X))\n",
    "# Mix a custom block, a built-in layer and another custom block in one\n",
    "# nn.Sequential, then evaluate the composite model on the batch X.\n",
    "chimera = nn.Sequential(NestMLP(), nn.Linear(16, 20), FixHiddenMLP())\n",
    "chimera(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a72f2a1c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "deep-learning-pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.22"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
